Maximum product of splitted binary tree [DFS]
Time: O(N); Space: O(H), where N is the number of nodes and H the height of the tree; difficulty: medium
Given the root of a binary tree, split the tree into two subtrees by removing exactly one edge so that the product of the two subtree sums is maximized.
Since the answer may be very large, return it modulo 10^9 + 7. The maximum must be taken over the exact products first, and only then reduced modulo 10^9 + 7.
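The order of operations matters: comparing products after reducing them can pick the wrong split. A minimal sketch with made-up candidate products (the values and the helper name `pick_answer` are illustrative, not part of the problem):

```python
MOD = 10**9 + 7


def pick_answer(products):
    """Choose the largest exact product, then reduce it modulo MOD."""
    return max(products) % MOD


# Hypothetical candidates: the first product is larger but wraps to 1 under MOD.
candidates = [MOD + 1, 5]

correct = pick_answer(candidates)             # max is MOD + 1, reduced to 1
wrong = max(p % MOD for p in candidates)      # compares 1 against 5, picks 5
print(correct, wrong)                         # 1 5
```

Python integers are arbitrary precision, so the exact products never overflow before the final reduction.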
Example 1:
Input: root = {TreeNode} [1,2,3,4,5,6]
Output: 110
Explanation:
Remove the edge between node 1 and its left child 2 to get two binary trees with sums 11 and 10. Their product is 110 (11 * 10).
Example 2:
Input: root = {TreeNode} [1,null,2,3,4,null,null,5,6]
Output: 90
Explanation:
Remove the edge between node 2 and its right child 4 to get two binary trees with sums 15 and 6. Their product is 90 (15 * 6).
Example 3:
Input: root = {TreeNode} [2,3,9,10,7,8,6,5,4,11,1]
Output: 1025
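Example 3 has no explanation; the arithmetic below reconstructs one from the level-order input, read as: root 2 with children 3 and 9, node 3 with children 10 and 7, node 9 with children 8 and 6, node 10 with children 5 and 4, and node 7 with children 11 and 1. Here S(v) is the sum of the subtree rooted at v and T the total.

```latex
\begin{aligned}
T &= 2+3+9+10+7+8+6+5+4+11+1 = 66\\
S(3) &= 3 + (10+5+4) + (7+11+1) = 41\\
S(3)\,\bigl(T - S(3)\bigr) &= 41 \cdot 25 = 1025\\
S(9)\,\bigl(T - S(9)\bigr) &= 23 \cdot 43 = 989
\end{aligned}
```

Cutting above node 10 or node 7 gives 19 · 47 = 893, and every leaf cut is smaller still, so removing the edge between 2 and 3 is optimal.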
Example 4:
Input: root = {TreeNode} [1,1]
Output: 1
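The inputs above are LeetCode-style level-order lists in which `null` marks a missing child. A minimal sketch of a builder that turns such a list into linked nodes; `build_tree` and `tree_sum` are hypothetical helper names, and Python's `None` stands in for `null`:

```python
from collections import deque


class TreeNode(object):
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Build a binary tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two slots in the list are this node's left and right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def tree_sum(node):
    """Sum of all values in the tree (recursive; fine for small examples)."""
    return 0 if node is None else node.val + tree_sum(node.left) + tree_sum(node.right)


root = build_tree([1, None, 2, 3, 4, None, None, 5, 6])  # Example 2
print(root.right.right.val, tree_sum(root))               # 4 21
```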
Constraints:
The tree has between 2 and 50000 nodes.
Each node's value is in the range [1, 10000].
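These bounds also limit how large the intermediate values can get. The total is at most 50000 × 10000 = 5 × 10^8, and for a fixed total T the product s(T − s) peaks at s = T/2, so no candidate product exceeds (T/2)^2. A quick check of that arithmetic, which only matters in languages with fixed-width integers since Python's ints are unbounded:

```python
MAX_NODES = 50_000
MAX_VALUE = 10_000

max_total = MAX_NODES * MAX_VALUE      # largest possible tree sum
max_product = (max_total // 2) ** 2    # upper bound: s * (T - s) peaks at s = T / 2

print(max_total, max_product)          # 500000000 62500000000000000
print(max_product <= 2**31 - 1)        # False: would overflow a 32-bit int
print(max_product <= 2**63 - 1)        # True: fits in a signed 64-bit int
```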
Hints:
If we know the sum of every subtree, the answer is the maximum of (total_sum - subtree_sum) * subtree_sum over all nodes.
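As a formula, with S(v) the sum of the subtree rooted at v and T = S(root): removing the edge above a non-root node v leaves one part with sum S(v) and the other with T − S(v), so

```latex
\text{answer} = \Bigl(\max_{v \neq \text{root}} S(v)\,\bigl(T - S(v)\bigr)\Bigr) \bmod (10^9 + 7),
\qquad
S(v) = \mathrm{val}(v) + S(\mathrm{left}(v)) + S(\mathrm{right}(v)),\quad S(\varnothing) = 0.
```

Letting v range over the root as well is harmless, because S(root)(T − S(root)) = 0 can never be the maximum when a positive split exists.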
class TreeNode(object):
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
1. Depth First Search
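class Solution1(object):
    """
    Time: O(N)
    Space: O(H), where H is the tree height (recursion depth).
    Note: a very skewed tree can exceed Python's default recursion limit.
    """
    def maxProduct(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        MOD = 10**9 + 7

        def dfs(node, total, result):
            # Return the sum of the subtree rooted at `node`; along the way,
            # record the best product obtained by cutting the edge above it.
            if not node:
                return 0
            subtotal = dfs(node.left, total, result) + \
                dfs(node.right, total, result) + \
                node.val
            result[0] = max(result[0], subtotal * (total - subtotal))
            return subtotal

        result = [0]  # one-element list so the nested function can update it
        # The inner call (total=0) only computes the tree sum: its products are
        # negative and never beat the initial 0. The outer call uses the real total.
        dfs(root, dfs(root, 0, result), result)
        # Maximize on exact values first, then reduce modulo 10^9 + 7.
        return result[0] % MOD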
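The recursive `dfs` goes as deep as the tree is tall, and a degenerate tree with 50000 nodes is far deeper than CPython's default recursion limit of 1000 frames. A minimal sketch of an iterative variant, assuming the same `TreeNode` shape (`Solution2` is a hypothetical name): it records nodes in pre-order with an explicit stack, then walks that order in reverse so every child is summed before its parent, collects all subtree sums, and scans them once.

```python
class TreeNode(object):
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution2(object):
    """
    Time: O(N)
    Space: O(N) for the explicit stack, visit order, and stored subtree sums
    """
    def maxProduct(self, root):
        MOD = 10**9 + 7
        # Pre-order traversal with an explicit stack: parents before descendants.
        order, stack = [], [root]
        while stack:
            node = stack.pop()
            if node:
                order.append(node)
                stack.append(node.left)
                stack.append(node.right)
        # In reversed pre-order every node follows its descendants,
        # so its children's sums are already known when it is processed.
        sums = {}
        for node in reversed(order):
            sums[node] = (node.val
                          + sums.get(node.left, 0)
                          + sums.get(node.right, 0))
        total = sums[root]
        best = max(s * (total - s) for s in sums.values())
        return best % MOD


# Usage: a left-leaning chain of 50000 nodes, each with value 1.
chain = TreeNode(1)
node = chain
for _ in range(49_999):
    node.left = TreeNode(1)
    node = node.left
print(Solution2().maxProduct(chain))  # 625000000 (25000 * 25000)
```

The trade-off is memory: the `order` list and `sums` dictionary take O(N) space instead of the recursive version's O(H) stack, in exchange for not depending on the interpreter's recursion limit.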
s = Solution1()
root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left, root.left.right = TreeNode(4), TreeNode(5)
root.right.left = TreeNode(6)
assert s.maxProduct(root) == 110
root = TreeNode(1)
root.right = TreeNode(2)
root.right.left, root.right.right = TreeNode(3), TreeNode(4)
root.right.right.left, root.right.right.right = TreeNode(5), TreeNode(6)
assert s.maxProduct(root) == 90
root = TreeNode(2)
root.left, root.right = TreeNode(3), TreeNode(9)
root.left.left, root.left.right = TreeNode(10), TreeNode(7)
root.right.left, root.right.right = TreeNode(8), TreeNode(6)
root.left.left.left, root.left.left.right = TreeNode(5), TreeNode(4)
root.left.right.left, root.left.right.right = TreeNode(11), TreeNode(1)
assert s.maxProduct(root) == 1025
root = TreeNode(1)
root.left = TreeNode(1)
assert s.maxProduct(root) == 1
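Beyond the four fixed examples, the O(N) idea can be cross-checked against a literal reading of the statement: cut each edge in turn, recompute both sums from scratch, and keep the best product. The sketch below does this on small random trees; `brute_force`, `fast`, and `random_tree` are hypothetical names, and the O(N^2) brute force is only meant for tiny inputs.

```python
import random

MOD = 10**9 + 7


class TreeNode(object):
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node):
    return 0 if node is None else node.val + tree_sum(node.left) + tree_sum(node.right)


def brute_force(root):
    """Detach each child in turn, recompute both sums, then restore the edge."""
    best = 0
    stack = [root]
    while stack:
        node = stack.pop()
        for side in ("left", "right"):
            child = getattr(node, side)
            if child is None:
                continue
            setattr(node, side, None)        # remove the edge node -> child
            best = max(best, tree_sum(root) * tree_sum(child))
            setattr(node, side, child)       # put the edge back
            stack.append(child)
    return best % MOD


def fast(root):
    """Same idea as Solution1: total first, then every subtree sum."""
    total = tree_sum(root)
    best = [0]

    def dfs(node):
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        best[0] = max(best[0], s * (total - s))
        return s

    dfs(root)
    return best[0] % MOD


def random_tree(n, rng):
    """Grow a tree of n nodes (values 1..10000) by filling random free child slots."""
    root = TreeNode(rng.randint(1, 10000))
    slots = [(root, "left"), (root, "right")]  # free child positions
    for _ in range(n - 1):
        parent, side = slots.pop(rng.randrange(len(slots)))
        child = TreeNode(rng.randint(1, 10000))
        setattr(parent, side, child)
        slots += [(child, "left"), (child, "right")]
    return root


rng = random.Random(0)
for _ in range(200):
    t = random_tree(rng.randint(2, 30), rng)
    assert brute_force(t) == fast(t)
print("ok")
```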